{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 用于cifar10图像分类的5层卷积神经网络"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.导入模块,数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "# 导入模块\n",
    "import os\n",
    "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"]=\"3\"    #解决输出警告from ._conv import register_converters as _register_converters\n",
    "os.chdir('/home/david/tensorflow/卷积神经网络/cifar10')     #切换到cifar10目录\n",
    "import cifar10, cifar10_input\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 80\r\n",
      "-rwxrwxrwx 1 david david  1683 4月   4 14:08 BUILD\r\n",
      "drwxr-xr-x 2 david david  4096 6月   5  2009 cifar-10-batches-bin\r\n",
      "-rwxrwxrwx 1 david david  5458 4月   4 14:08 cifar10_eval.py\r\n",
      "-rwxrwxrwx 1 david david 10209 4月   4 14:08 cifar10_input.py\r\n",
      "-rwxrwxrwx 1 david david  2274 4月   4 14:08 cifar10_input_test.py\r\n",
      "-rwxrwxrwx 1 david david 10648 4月   4 14:08 cifar10_multi_gpu_train.py\r\n",
      "-rwxrwxrwx 1 david david 14675 4月   4 14:08 cifar10.py\r\n",
      "-rwxrwxrwx 1 david david  4491 4月   4 14:08 cifar10_train.py\r\n",
      "-rwxrwxrwx 1 david david   899 4月   4 14:08 __init__.py\r\n",
      "drwxrwxr-x 2 david david  4096 4月   5 22:16 __pycache__\r\n",
      "-rwxrwxrwx 1 david david   624 4月   4 14:08 README.md\r\n"
     ]
    }
   ],
   "source": [
    "# 下载数据集(已下载)\n",
    "# cifar10.maybe_download_and_extract() # 下载的位置为`/tmp/cifar10_data/cifar-10-batches-bin`\n",
    "!ls -l # cifar-10-batches-bin 是存放cifar10数据的文件夹"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.\n"
     ]
    }
   ],
   "source": [
    "# 数据预处理\n",
    "max_steps = 3000\n",
    "batch_size = 128\n",
    "data_dir = './cifar-10-batches-bin'\n",
    "# 获得数据增强,增广的训练集(左右翻转,随机裁剪,随机对比度,随机亮度及数据标准化)\n",
    "images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=batch_size)\n",
    "# 测试集只裁剪中间的24*24,并数据标准化处理\n",
    "images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=data_dir, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.定义权重初始化函数及训练集数据占位符变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义输入数据及其标签\n",
    "image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])   # 输入的图像\n",
    "label_holder = tf.placeholder(tf.int32, [batch_size])              # 标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 权重初始化函数\n",
    "def variable_with_weight_loss(shape, stddev, w1):\n",
    "    \"\"\"\n",
    "    功能: 为权重进行初始化,并给权重添加一定的损失\n",
    "    参数: shape:权重向量的形状;stddev:标准差的大小;w1:控制权重的损失大小\n",
    "    返回: 初始化的权重向量\n",
    "    \"\"\"\n",
    "    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))   # 用截断的正态分布初始化权重\n",
    "    if w1 is not None:  # 如果为权重添加l2损失\n",
    "        weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')\n",
    "        tf.add_to_collection('losses', weight_loss)          # 将weight loss添加到总的loss中\n",
    "    return var"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.定义网络的结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第一层卷积层[5,5,3,64]\n",
    "kernel1 = variable_with_weight_loss([5, 5, 3, 64], stddev=5e-2, w1=0.0)   # 不计算weight的loss\n",
    "conv_kernel1 = tf.nn.conv2d(image_holder, kernel1, [1, 1, 1, 1], padding='SAME')\n",
    "bias1 = tf.Variable(tf.constant(0.0, shape=[64]))\n",
    "conv1 = tf.nn.relu(conv_kernel1 + bias1)                                              # tf.nn.bias_add()功能类似,卷积           \n",
    "pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME') # 最大池化\n",
    "norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)         # lrn局部响应均值化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第二层卷积层[5,5,64,64]\n",
    "kernel2 = variable_with_weight_loss([5, 5, 64, 64], stddev=5e-2, w1=0.0)\n",
    "conv_kernel2 = tf.nn.conv2d(norm1, kernel2, [1, 1, 1, 1], padding='SAME')\n",
    "bias2 = tf.Variable(tf.constant(0.0, shape=[64]))\n",
    "conv2 = tf.nn.relu(conv_kernel2 + bias2)\n",
    "norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)\n",
    "pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第三层全连接层[384]\n",
    "pool2_flat = tf.reshape(pool2, [batch_size, -1])   # 将feature map重塑成一个一维度向量\n",
    "dim = pool2_flat.get_shape()[1].value              # 获取向量的维度\n",
    "weight3 = variable_with_weight_loss([dim, 384], stddev=0.04, w1=0.004)  # 全连接层的权重,加入损失\n",
    "bias3 = tf.Variable(tf.constant(0.0, shape=[384])) # 偏置项\n",
    "fc3 = tf.nn.relu(tf.matmul(pool2_flat, weight3) + bias3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第四层全连接层 [192]\n",
    "weight4 = variable_with_weight_loss([384, 192], stddev=0.04, w1=0.004)   # 全连接层的权重,加入损失\n",
    "bias4 = tf.Variable(tf.constant(0.0, shape=[192]))\n",
    "fc4 = tf.nn.relu(tf.matmul(fc3, weight4) + bias4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第五层全连接层==>分类层,概率层\n",
    "weight5 = variable_with_weight_loss([192, 10], stddev=1/192.0, w1=0.0) # 标准差为节点的倒数\n",
    "bias5 = tf.Variable(tf.constant(0.0, shape=[10]))\n",
    "logits = tf.matmul(fc4, weight5) + bias5                               # 不需要relu非线性激活"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义总的损失函数并加入权重的l2 loss\n",
    "def loss(logits, labels):\n",
    "    \"\"\"\n",
    "    功能: 计算预测值与标签的损失+全连接层的权重损失\n",
    "    参数: logits:预测输出,这里为特征不是概率; labels:训练集标签\n",
    "    返回: 总的损失\n",
    "    \"\"\"\n",
    "    # 求batch_size中每个样本的损失\n",
    "    labels = tf.cast(labels, tf.int64)                 # 将int32的labels类型变为int64\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, \\\n",
    "                                                  labels=labels, name='cross_entropy_per_examples')\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')   # 对batch_size的样本损失求平均\n",
    "    tf.add_to_collection('losses', cross_entropy_mean)   # 将模型的输出损失添加到总的损失中\n",
    "    losses = tf.get_collection('losses')               # 获取所有的loss\n",
    "    total_loss = tf.add_n(losses, name='total_loss')   # 对所有loss求和,得到总的loss\n",
    "    return total_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.定义准确率函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 计算网络总的损失\n",
    "loss = loss(logits, label_holder)\n",
    "# 迭代优化的方法,1e-3为学习率\n",
    "train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)\n",
    "# top k准确率,默认为1\n",
    "top_k_op = tf.nn.in_top_k(logits, label_holder, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.创建会话,初始变量并训练网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<Thread(QueueRunnerThread-input_producer-input_producer/input_producer_EnqueueMany, started daemon 139786622859008)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786614466304)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786606073600)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786597680896)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786589288192)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786580895488)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786572502784)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786018879232)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786010486528)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139786002093824)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785993701120)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785985308416)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785976915712)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785968523008)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785482008320)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785473615616)>,\n",
       " <Thread(QueueRunnerThread-shuffle_batch/random_shuffle_queue-shuffle_batch/random_shuffle_queue_enqueue, started daemon 139785465222912)>,\n",
       " <Thread(QueueRunnerThread-input/input_producer-input/input_producer/input_producer_EnqueueMany, started daemon 139785456830208)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139785448437504)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139785440044800)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139785431652096)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784945137408)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784936744704)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784928352000)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784919959296)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784911566592)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784903173888)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784894781184)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784408266496)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784399873792)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784391481088)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784383088384)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784374695680)>,\n",
       " <Thread(QueueRunnerThread-batch/fifo_queue-batch/fifo_queue_enqueue, started daemon 139784366302976)>]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess = tf.InteractiveSession()            # 创建会话\n",
    "tf.global_variables_initializer().run()   # 全局变量初始化\n",
    "tf.train.start_queue_runners()            # 使用16个线程来加速数据增加图像预处理,否则后续训练无法运行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0,loss=6.31 (14.7 examples/sec; 8.698 sec/batch)\n",
      "step 10,loss=4.61 (1883.9 examples/sec; 0.068 sec/batch)\n",
      "step 20,loss=3.82 (1815.0 examples/sec; 0.071 sec/batch)\n",
      "step 30,loss=3.06 (1852.7 examples/sec; 0.069 sec/batch)\n",
      "step 40,loss=2.67 (1774.6 examples/sec; 0.072 sec/batch)\n",
      "step 50,loss=2.40 (1816.7 examples/sec; 0.070 sec/batch)\n",
      "step 60,loss=2.21 (1845.8 examples/sec; 0.069 sec/batch)\n",
      "step 70,loss=2.11 (1801.0 examples/sec; 0.071 sec/batch)\n",
      "step 80,loss=2.14 (1991.1 examples/sec; 0.064 sec/batch)\n",
      "step 90,loss=1.95 (1897.5 examples/sec; 0.067 sec/batch)\n",
      "step 100,loss=1.93 (1900.2 examples/sec; 0.067 sec/batch)\n",
      "step 110,loss=1.94 (1966.7 examples/sec; 0.065 sec/batch)\n",
      "step 120,loss=1.83 (1878.7 examples/sec; 0.068 sec/batch)\n",
      "step 130,loss=1.83 (1836.5 examples/sec; 0.070 sec/batch)\n",
      "step 140,loss=1.90 (1994.6 examples/sec; 0.064 sec/batch)\n",
      "step 150,loss=1.83 (1845.9 examples/sec; 0.069 sec/batch)\n",
      "step 160,loss=1.75 (1897.5 examples/sec; 0.067 sec/batch)\n",
      "step 170,loss=1.87 (1919.2 examples/sec; 0.067 sec/batch)\n",
      "step 180,loss=1.79 (1808.1 examples/sec; 0.071 sec/batch)\n",
      "step 190,loss=1.73 (1891.1 examples/sec; 0.068 sec/batch)\n",
      "step 200,loss=2.00 (1793.6 examples/sec; 0.071 sec/batch)\n",
      "step 210,loss=1.76 (2003.2 examples/sec; 0.064 sec/batch)\n",
      "step 220,loss=1.93 (1970.0 examples/sec; 0.065 sec/batch)\n",
      "step 230,loss=1.85 (1822.4 examples/sec; 0.070 sec/batch)\n",
      "step 240,loss=1.71 (1798.1 examples/sec; 0.071 sec/batch)\n",
      "step 250,loss=1.55 (1824.2 examples/sec; 0.070 sec/batch)\n",
      "step 260,loss=1.68 (1780.3 examples/sec; 0.072 sec/batch)\n",
      "step 270,loss=1.60 (1982.3 examples/sec; 0.065 sec/batch)\n",
      "step 280,loss=1.67 (1930.1 examples/sec; 0.066 sec/batch)\n",
      "step 290,loss=1.77 (1964.0 examples/sec; 0.065 sec/batch)\n",
      "step 300,loss=1.56 (1995.1 examples/sec; 0.064 sec/batch)\n",
      "step 310,loss=1.67 (1816.6 examples/sec; 0.070 sec/batch)\n",
      "step 320,loss=1.70 (2012.5 examples/sec; 0.064 sec/batch)\n",
      "step 330,loss=1.64 (1913.3 examples/sec; 0.067 sec/batch)\n",
      "step 340,loss=1.69 (1990.6 examples/sec; 0.064 sec/batch)\n",
      "step 350,loss=1.67 (1920.4 examples/sec; 0.067 sec/batch)\n",
      "step 360,loss=1.62 (1924.4 examples/sec; 0.067 sec/batch)\n",
      "step 370,loss=1.81 (1797.7 examples/sec; 0.071 sec/batch)\n",
      "step 380,loss=1.77 (1977.7 examples/sec; 0.065 sec/batch)\n",
      "step 390,loss=1.75 (1787.3 examples/sec; 0.072 sec/batch)\n",
      "step 400,loss=1.67 (1947.9 examples/sec; 0.066 sec/batch)\n",
      "step 410,loss=1.55 (1808.0 examples/sec; 0.071 sec/batch)\n",
      "step 420,loss=1.65 (1819.1 examples/sec; 0.070 sec/batch)\n",
      "step 430,loss=1.61 (1844.8 examples/sec; 0.069 sec/batch)\n",
      "step 440,loss=1.57 (1838.8 examples/sec; 0.070 sec/batch)\n",
      "step 450,loss=1.36 (1892.0 examples/sec; 0.068 sec/batch)\n",
      "step 460,loss=1.46 (1921.0 examples/sec; 0.067 sec/batch)\n",
      "step 470,loss=1.61 (1873.7 examples/sec; 0.068 sec/batch)\n",
      "step 480,loss=1.38 (1802.2 examples/sec; 0.071 sec/batch)\n",
      "step 490,loss=1.42 (1850.1 examples/sec; 0.069 sec/batch)\n",
      "step 500,loss=1.56 (1865.4 examples/sec; 0.069 sec/batch)\n",
      "step 510,loss=1.58 (1944.6 examples/sec; 0.066 sec/batch)\n",
      "step 520,loss=1.57 (1940.3 examples/sec; 0.066 sec/batch)\n",
      "step 530,loss=1.47 (1976.7 examples/sec; 0.065 sec/batch)\n",
      "step 540,loss=1.56 (1874.6 examples/sec; 0.068 sec/batch)\n",
      "step 550,loss=1.35 (1990.6 examples/sec; 0.064 sec/batch)\n",
      "step 560,loss=1.46 (1874.4 examples/sec; 0.068 sec/batch)\n",
      "step 570,loss=1.42 (1725.0 examples/sec; 0.074 sec/batch)\n",
      "step 580,loss=1.54 (1905.0 examples/sec; 0.067 sec/batch)\n",
      "step 590,loss=1.42 (1957.8 examples/sec; 0.065 sec/batch)\n",
      "step 600,loss=1.62 (2087.9 examples/sec; 0.061 sec/batch)\n",
      "step 610,loss=1.52 (2004.7 examples/sec; 0.064 sec/batch)\n",
      "step 620,loss=1.47 (2015.1 examples/sec; 0.064 sec/batch)\n",
      "step 630,loss=1.47 (1865.1 examples/sec; 0.069 sec/batch)\n",
      "step 640,loss=1.48 (1889.6 examples/sec; 0.068 sec/batch)\n",
      "step 650,loss=1.58 (1773.3 examples/sec; 0.072 sec/batch)\n",
      "step 660,loss=1.43 (1971.6 examples/sec; 0.065 sec/batch)\n",
      "step 670,loss=1.37 (1886.3 examples/sec; 0.068 sec/batch)\n",
      "step 680,loss=1.40 (1885.2 examples/sec; 0.068 sec/batch)\n",
      "step 690,loss=1.32 (1897.8 examples/sec; 0.067 sec/batch)\n",
      "step 700,loss=1.45 (1893.2 examples/sec; 0.068 sec/batch)\n",
      "step 710,loss=1.47 (1891.0 examples/sec; 0.068 sec/batch)\n",
      "step 720,loss=1.37 (1883.2 examples/sec; 0.068 sec/batch)\n",
      "step 730,loss=1.37 (1865.4 examples/sec; 0.069 sec/batch)\n",
      "step 740,loss=1.45 (1986.1 examples/sec; 0.064 sec/batch)\n",
      "step 750,loss=1.26 (1844.3 examples/sec; 0.069 sec/batch)\n",
      "step 760,loss=1.33 (1977.8 examples/sec; 0.065 sec/batch)\n",
      "step 770,loss=1.30 (1963.6 examples/sec; 0.065 sec/batch)\n",
      "step 780,loss=1.47 (1818.1 examples/sec; 0.070 sec/batch)\n",
      "step 790,loss=1.32 (1960.1 examples/sec; 0.065 sec/batch)\n",
      "step 800,loss=1.36 (1925.6 examples/sec; 0.066 sec/batch)\n",
      "step 810,loss=1.45 (2012.2 examples/sec; 0.064 sec/batch)\n",
      "step 820,loss=1.48 (2055.3 examples/sec; 0.062 sec/batch)\n",
      "step 830,loss=1.39 (1897.6 examples/sec; 0.067 sec/batch)\n",
      "step 840,loss=1.49 (2005.4 examples/sec; 0.064 sec/batch)\n",
      "step 850,loss=1.32 (1774.7 examples/sec; 0.072 sec/batch)\n",
      "step 860,loss=1.39 (1908.6 examples/sec; 0.067 sec/batch)\n",
      "step 870,loss=1.32 (1747.0 examples/sec; 0.073 sec/batch)\n",
      "step 880,loss=1.42 (1968.3 examples/sec; 0.065 sec/batch)\n",
      "step 890,loss=1.37 (1890.7 examples/sec; 0.068 sec/batch)\n",
      "step 900,loss=1.45 (1820.3 examples/sec; 0.070 sec/batch)\n",
      "step 910,loss=1.32 (1875.5 examples/sec; 0.068 sec/batch)\n",
      "step 920,loss=1.43 (1868.9 examples/sec; 0.068 sec/batch)\n",
      "step 930,loss=1.37 (1948.7 examples/sec; 0.066 sec/batch)\n",
      "step 940,loss=1.51 (1842.7 examples/sec; 0.069 sec/batch)\n",
      "step 950,loss=1.31 (1931.5 examples/sec; 0.066 sec/batch)\n",
      "step 960,loss=1.49 (1705.4 examples/sec; 0.075 sec/batch)\n",
      "step 970,loss=1.28 (1720.1 examples/sec; 0.074 sec/batch)\n",
      "step 980,loss=1.41 (1965.8 examples/sec; 0.065 sec/batch)\n",
      "step 990,loss=1.42 (1857.0 examples/sec; 0.069 sec/batch)\n"
     ]
    }
   ],
   "source": [
    "# 开始训练网络\n",
    "for step in range(1000):\n",
    "    start_time = time.time()  # 开始时间\n",
    "    images_batch, label_batch = sess.run([images_train, labels_train])   #获取batch_size的训练数据\n",
    "    # 开始训练\n",
    "    _, loss_value = sess.run([train_op, loss], \\\n",
    "                             feed_dict={image_holder: images_batch, label_holder: label_batch})\n",
    "    duration = time.time() - start_time    # 训练完一个batch_size的时间\n",
    "    if step % 10 == 0:   # 迭代10次\n",
    "        examples_per_sec = batch_size / duration  # 每s迭代的样本数\n",
    "        sec_per_batch = float(duration)\n",
    "        train_info = 'step %d,loss=%.2f (%.1f examples/sec; %.3f sec/batch)'\n",
    "        print(train_info % (step, loss_value, examples_per_sec, sec_per_batch))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.测试模型准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "collect number: 109\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 111\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 105\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 108\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 99\n",
      "---------------------------------------\n",
      "collect number: 95\n",
      "---------------------------------------\n",
      "collect number: 100\n",
      "---------------------------------------\n",
      "collect number: 89\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 102\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 96\n",
      "---------------------------------------\n",
      "collect number: 105\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 107\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 107\n",
      "---------------------------------------\n",
      "collect number: 110\n",
      "---------------------------------------\n",
      "collect number: 94\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 106\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 102\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 100\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 106\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 92\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 97\n",
      "---------------------------------------\n",
      "collect number: 105\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 97\n",
      "---------------------------------------\n",
      "collect number: 93\n",
      "---------------------------------------\n",
      "collect number: 88\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 102\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 99\n",
      "---------------------------------------\n",
      "collect number: 106\n",
      "---------------------------------------\n",
      "collect number: 94\n",
      "---------------------------------------\n",
      "collect number: 97\n",
      "---------------------------------------\n",
      "collect number: 103\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 100\n",
      "---------------------------------------\n",
      "collect number: 97\n",
      "---------------------------------------\n",
      "collect number: 107\n",
      "---------------------------------------\n",
      "collect number: 111\n",
      "---------------------------------------\n",
      "collect number: 96\n",
      "---------------------------------------\n",
      "collect number: 93\n",
      "---------------------------------------\n",
      "collect number: 102\n",
      "---------------------------------------\n",
      "collect number: 102\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 101\n",
      "---------------------------------------\n",
      "collect number: 98\n",
      "---------------------------------------\n",
      "collect number: 97\n",
      "---------------------------------------\n",
      "collect number: 100\n",
      "---------------------------------------\n",
      "collect number: 99\n",
      "---------------------------------------\n",
      "collect number: 115\n",
      "---------------------------------------\n",
      "collect number: 104\n",
      "---------------------------------------\n",
      "collect number: 97\n",
      "---------------------------------------\n",
      "collect number: 107\n",
      "---------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# 用测试集测试模型准确率\n",
    "num_examples = 10000\n",
    "import math\n",
    "num_iter = int(math.ceil(num_examples / batch_size))   # ceil函数的作用是将一个小数取比它大的整数,如:30.1取31\n",
    "true_count = 0\n",
    "total_example_count = num_iter * batch_size\n",
    "step = 0\n",
    "while step < num_iter:\n",
    "    image_batch, label_batch = sess.run([images_test, labels_test])   # 获取batch_size的测试数据\n",
    "    # 开始测试\n",
    "    predictions = sess.run([top_k_op], \\\n",
    "                          feed_dict={image_holder: image_batch, label_holder: label_batch})\n",
    "#     print(predictions)\n",
    "    print('collect number:', np.sum(predictions))  \n",
    "    true_count += np.sum(predictions)              \n",
    "    step += 1\n",
    "    print('---------------------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision @ 1 = 0.790\n"
     ]
    }
   ],
   "source": [
    "precision = true_count / total_example_count\n",
    "print('precision @ 1 = %.3f' % precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
